# -*- coding: utf-8 -*-
import csv
import os
import threading

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Directory TensorBoard reads event files from (also the target of clearFile_log).
model_dir = 'tflog'
# Data folder prefix, concatenated directly with file names below.
# NOTE(review): Windows-style trailing backslash — this module appears to
# assume a Windows host (see the hard-coded chrome.exe path as well).
data_folder = 'tfdata\\'


def start_tensorboard_thread():
    """Launch TensorBoard on ``tflog`` in a background thread.

    Returns:
        The started ``threading.Thread`` so callers may join it if desired
        (previous callers that ignore the return value are unaffected).
    """

    def inner():
        # os.system blocks until the tensorboard process exits, which is
        # why it runs on its own thread.  The original also did
        # os.chdir(os.getcwd()) here — a no-op, removed.
        os.system('tensorboard --logdir=tflog')

    # daemon=True lets the interpreter exit without waiting for this
    # thread; the spawned tensorboard subprocess is independent anyway.
    th = threading.Thread(target=inner, daemon=True)
    th.start()
    return th


def open_chrome_thread():
    """Open the TensorBoard page (host imi-bes, port 6006) in Chrome.

    Runs ``os.system`` on a background thread because it blocks until the
    launched process returns.
    """

    def inner():
        # Bug fix: the URL previously used backslashes ("http:\\imi-bes:6006"),
        # which is not a valid URL scheme separator; it must be "http://".
        os.system(r"C:\Users\Hs\AppData\Local\Google\Chrome\Application\chrome.exe http://imi-bes:6006")

    th = threading.Thread(target=inner)
    th.start()


def clearFile_log():
    """Empty the TensorBoard log directory (``model_dir``) recursively."""
    clearFile(path=model_dir)
    print('clear file done')


def clearFile(path=None):
    """Recursively delete every file under *path*.

    Subdirectories are descended into but not removed themselves; only
    regular files are deleted.
    """
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if not os.path.isfile(full_path):
            # Directory (or other non-file): recurse into it.
            clearFile(full_path)
        else:
            os.remove(full_path)


def load_csv(filename, x_dtype, y_dtype, y_column=-1):
    """Read a CSV file into feature and label arrays.

    The column at index *y_column* (default: last) of each row becomes the
    label; all remaining columns become the feature vector.

    Returns:
        Tuple ``(x, y)`` where ``x`` has shape (rows, cols-1) and ``y``
        has shape (rows, 1).
    """
    features, labels = [], []
    with gfile.Open(filename) as fh:
        for record in csv.reader(fh):
            # pop() removes the label so the remainder is pure features.
            labels.append([record.pop(y_column)])
            features.append(np.asarray(record, dtype=x_dtype))
    x = np.array(features, dtype=x_dtype)
    y = np.array(labels, dtype=y_dtype)
    print('x.shape = ' + str(x.shape) + '    y.shape = ' + str(y.shape))
    return x, y


def to_onehot(y, typeNum):
    """Convert class indices to a one-hot matrix.

    Args:
        y: sequence of class indices (ints or floats; floats are truncated
           toward zero, matching ``int()``).
        typeNum: number of classes, i.e. the width of each one-hot row.

    Returns:
        Float ndarray of shape (len(y), typeNum) with exactly one 1 per row.
    """
    indices = np.asarray(y, dtype=int).reshape(-1)
    onehot = np.zeros([len(indices), typeNum])
    # Advanced indexing sets all the hot positions in a single pass.
    onehot[np.arange(len(indices)), indices] = 1
    return onehot


def readData_csv(filename, splitRatio=1):
    """Load a float32 CSV dataset, optionally shuffle-splitting it.

    Args:
        filename: CSV path; last column is the label.
        splitRatio: fraction of rows used for training.  When 1 (default)
            the whole dataset is returned as a single ``(x, y)`` pair;
            otherwise returns ``((x_train, y_train), (x_test, y_test))``
            after a random permutation.
    """
    x, y = load_csv(filename=filename, y_dtype=np.float32, x_dtype=np.float32)
    if splitRatio == 1:
        return (x, y)
    total = len(y)
    cut = int(total * splitRatio)
    order = np.random.permutation(total)
    train_idx, test_idx = order[:cut], order[cut:]
    return (x[train_idx], y[train_idx]), (x[test_idx], y[test_idx])


def readData_csv_onehot(filename, typeNum, splitRatio=1):
    """Load a float32 CSV dataset with labels one-hot encoded.

    Args:
        filename: CSV path; last column is the (integer) class label.
        typeNum: number of classes for the one-hot encoding.
        splitRatio: fraction of rows used for training.  When 1 (default)
            returns one ``(x, y)`` pair; otherwise a random permutation is
            split into ``((x_train, y_train), (x_test, y_test))``.
    """
    x, y = load_csv(filename=filename, y_dtype=np.float32, x_dtype=np.float32)
    y = to_onehot(y, typeNum)
    if splitRatio == 1:
        return (x, y)
    total = len(y)
    cut = int(total * splitRatio)
    order = np.random.permutation(total)
    train_idx, test_idx = order[:cut], order[cut:]
    return (x[train_idx], y[train_idx]), (x[test_idx], y[test_idx])


def _int64_feature(value):
    """Wrap an iterable of ints as a tf.train.Feature (Int64List)."""
    int_list = tf.compat.v1.train.Int64List(value=value)
    return tf.compat.v1.train.Feature(int64_list=int_list)


def _bytes_feature(value):
    """Wrap an iterable of byte strings as a tf.train.Feature (BytesList)."""
    byte_list = tf.compat.v1.train.BytesList(value=value)
    return tf.compat.v1.train.Feature(bytes_list=byte_list)


def _float_feature(value):
    """Wrap an iterable of floats as a tf.train.Feature (FloatList)."""
    float_list = tf.compat.v1.train.FloatList(value=value)
    return tf.compat.v1.train.Feature(float_list=float_list)


def convert_to_tfrecords(xs, ys, filename):
    """Serialize parallel feature/label sequences into a TFRecord file.

    Each row is written as a tf.train.Example with two float features:
    'x' (the feature vector) and 'y' (the label vector).

    Args:
        xs: indexable sequence of float iterables (one per example).
        ys: indexable sequence of float iterables; its length determines
            the number of records written.
        filename: output TFRecord path.
    """
    rows = len(ys)
    print(rows)
    # Context manager guarantees the writer is flushed and closed even if
    # serializing a row raises (the original leaked the writer on error).
    with tf.compat.v1.io.TFRecordWriter(filename) as writer:
        for row in range(rows):
            # Was tf.train.Features amid tf.compat.v1.* calls; made
            # consistent with the rest of the file.
            features = tf.compat.v1.train.Features(
                feature={'x': _float_feature(xs[row]), 'y': _float_feature(ys[row])})
            example = tf.compat.v1.train.Example(features=features)
            writer.write(example.SerializeToString())


class DataProcess():
    """Converts CSV datasets to TFRecord files and builds tf.data input
    pipelines / tf.estimator input functions over them.

    All file names are resolved by prefixing the module-level
    ``data_folder`` constant.
    """

    def __init__(self, xSize, trainName_tf, testName_tf, trainName_csv, testName_csv, typeNum=1):
        # Flattened feature-vector length of each example's 'x'.
        self.xSize = xSize
        # TFRecord file names (train / test), relative to data_folder.
        self.trainName_tf = trainName_tf
        self.testName_tf = testName_tf
        # Source CSV file names (train / test), relative to data_folder.
        self.trainName_csv = trainName_csv
        self.testName_csv = testName_csv
        # Class count for one-hot conversion; default 1 means labels are
        # used as-is (see convert_train_split_onehot vs convertData_train).
        self.typeNum = typeNum

    def input_train(self):
        """Estimator-style train input_fn: repeated, shuffled batches of 64."""
        return self.dataset_input_train(data_folder + self.trainName_tf, batch=64, buffer_size=12800)

    def input_test(self):
        """Estimator-style eval input_fn: single pass, batches of 128."""
        return self.dataset_input_test(data_folder + self.testName_tf, batch=128, buffer_size=12800)

    def serving_input_receiver_fn(self):
        """Build a ServingInputReceiver that parses serialized Examples.

        Accepts a batch of serialized tf.train.Example strings and extracts
        the fixed-length float feature 'x' of size self.xSize.
        """
        input_tensors = tf.compat.v1.placeholder(dtype=tf.string, shape=[None], name='input_tensors')
        receiver_tensors = {"receiver_tensors": input_tensors}
        feature_spec = {"x": tf.compat.v1.io.FixedLenFeature([self.xSize], tf.float32)}
        features = tf.compat.v1.io.parse_example(serialized=input_tensors, features=feature_spec)
        return tf.compat.v1.estimator.export.ServingInputReceiver(features, receiver_tensors)

    def parser(self, record):
        """Parse one serialized Example into ({'x': ...}, {'y': ...}) dicts.

        'x' is a length-xSize float vector, 'y' a length-1 float vector.
        The cast to float32 is redundant with the feature spec but kept.
        """
        keys_to_features = {'x': tf.compat.v1.io.FixedLenFeature([self.xSize], tf.float32), 'y': tf.compat.v1.io.FixedLenFeature([1], tf.float32)}
        parsed = tf.compat.v1.io.parse_single_example(serialized=record, features=keys_to_features)
        xData = tf.compat.v1.cast(parsed['x'], tf.float32)
        yData = tf.compat.v1.cast(parsed['y'], tf.float32)
        return {"x": xData}, {"y":yData}

    def parser2(self, record):
        """Parse one serialized Example into an image-shaped (x, y) pair.

        NOTE(review): the reshape hard-codes 224x224x1, which only matches
        instances constructed with xSize=224*224 (dataProcess_tic_type);
        using get_dataset_* on the other instances would fail at runtime —
        confirm intended usage.
        """
        keys_to_features = {'x': tf.io.FixedLenFeature([self.xSize], tf.float32), 'y': tf.io.FixedLenFeature([1], tf.float32)}
        parsed = tf.io.parse_single_example(serialized=record, features=keys_to_features)
        xData = tf.cast(parsed['x'], tf.float32)
        yData = tf.cast(parsed['y'], tf.float32)
        return tf.reshape(xData, shape=[224, 224, 1]), yData

    def dataset_input_train(self, filename, batch, buffer_size):
        """Graph-mode training pipeline: map -> shuffle -> batch -> repeat.

        Returns the (features, labels) tensors of a one-shot iterator, as
        expected by tf.estimator input_fns.  repeat() is unbounded, so the
        caller controls duration via train steps.
        """
        dataset = tf.compat.v1.data.TFRecordDataset(filename)
        dataset = dataset.map(self.parser)
        dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.batch(batch)
        dataset = dataset.repeat()
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        features, labels = iterator.get_next()
        return features, labels

    def dataset_input_test(self, filename, batch, buffer_size):
        """Graph-mode eval pipeline: like the train one but single-pass.

        NOTE(review): the test set is shuffled too, which is harmless for
        metrics but unusual — confirm it is intentional.
        """
        dataset = tf.compat.v1.data.TFRecordDataset(filename)
        dataset = dataset.map(self.parser)
        dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.batch(batch)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        features, labels = iterator.get_next()
        return features, labels

    def get_dataset_train(self, batch=64, buffer_size=128000):
        """tf.data training dataset (Keras-style) using parser2.

        repeat(4) fixes the pipeline at four epochs per fit() call.
        """
        dataset = tf.data.TFRecordDataset(data_folder + self.trainName_tf)
        dataset = dataset.map(self.parser2)
        dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.batch(batch)
        dataset = dataset.repeat(4)
        return dataset

    def get_dataset_test(self, batch=64, buffer_size=128000):
        """tf.data test dataset (Keras-style) using parser2; single pass."""
        dataset = tf.data.TFRecordDataset(data_folder + self.testName_tf)
        dataset = dataset.map(self.parser2)
        dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.batch(batch)
        return dataset

    def convert_train_split_onehot(self):
        """CSV -> TFRecords with one-hot labels, 80/20 random split.

        Reads trainName_csv and writes BOTH trainName_tf and testName_tf.
        """
        print('reading...')
        trainData, testData = readData_csv_onehot(filename=data_folder + self.trainName_csv, typeNum=self.typeNum, splitRatio=0.8)
        print('writing...')
        convert_to_tfrecords(trainData[0], trainData[1], data_folder + self.trainName_tf)
        convert_to_tfrecords(testData[0], testData[1], data_folder + self.testName_tf)
        print('done')

    def convertData_train_split(self):
        """CSV -> TFRecords with raw labels, 80/20 random split.

        Reads trainName_csv and writes BOTH trainName_tf and testName_tf.
        """
        print('reading...')
        trainData, testData = readData_csv(filename=data_folder + self.trainName_csv, splitRatio=0.8)
        print('writing...')
        convert_to_tfrecords(trainData[0], trainData[1], data_folder + self.trainName_tf)
        convert_to_tfrecords(testData[0], testData[1], data_folder + self.testName_tf)
        print('done')

    def convertData_train(self):
        """Convert the whole training CSV to trainName_tf (no split)."""
        print('reading...')
        trainData = readData_csv(filename=data_folder + self.trainName_csv)
        print('writing...')
        convert_to_tfrecords(trainData[0], trainData[1], data_folder + self.trainName_tf)
        print('done')

    def convertData_test(self):
        """Convert the whole test CSV to testName_tf (no split)."""
        print('reading...')
        testData = readData_csv(filename=data_folder + self.testName_csv)
        print('writing...')
        convert_to_tfrecords(testData[0], testData[1], data_folder + self.testName_tf)
        print('done')


# Pre-configured dataset processors for the project's three datasets.
# NOTE(review): the xSize values suggest flattened square images
# (64x64, 48x48, 224x224) — confirm against the CSV producers.
dataProcess_lwh_length = DataProcess(xSize=64 * 64,
                                     trainName_tf='lwh_train_length.tf',
                                     testName_tf='lwh_test_length.tf',
                                     trainName_csv='lwh_train_length.csv',
                                     testName_csv='lwh_test_length.csv')

dataProcess_tic_length = DataProcess(xSize=48 * 48,
                                     trainName_tf='tic_train_length.tf',
                                     testName_tf='tic_test_length.tf',
                                     trainName_csv='tic_train_length.csv',
                                     testName_csv='tic_test_length.csv')

# 17-class classification variant; only this instance matches parser2's
# hard-coded 224x224 reshape.
dataProcess_tic_type = DataProcess(xSize=224 * 224,
                                   trainName_tf='tic_train_type.tf',
                                   testName_tf='tic_test_type.tf',
                                   trainName_csv='tic_train_type.csv',
                                   testName_csv='tic_test_type.csv',
                                   typeNum=17)